package io.sundial.console.job;

import io.sundial.coordination.Coordinator;
import io.sundial.coordination.CoordinatorException;
import io.sundial.coordination.node.Node;
import io.sundial.coordination.tree.TreeEvent;
import io.sundial.coordination.tree.TreeWatcher;
import io.sundial.core.context.Context;
import io.sundial.core.lifecycle.Lifecycle;
import io.sundial.core.lifecycle.exception.InitializingException;
import io.sundial.job.JobDefinition;
import io.sundial.job.JobKey;
import io.sundial.protocol.Protocol;
import io.sundial.protocol.exception.MarshallingException;
import io.sundial.protocol.exception.UnmarshallingException;
import io.sundial.util.ResKit;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.*;

import javax.annotation.Resource;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Job controller.
 * <p>
 * Serves job lookups over REST from an in-memory cache. The cache is kept up to date by
 * watching the {@code /jobs} subtree of the {@link Coordinator}; this class is registered
 * as the {@link TreeWatcher} for that subtree in {@link #setApplicationContext}.
 *
 * @author Payne 646742615@qq.com
 * 2019/1/7 10:44
 * @tag Job
 */
@CrossOrigin
@RestController
@RequestMapping("/api/job")
public class JobController implements TreeWatcher, ApplicationContextAware {
    // Node paths are split on "/". A path with four segments is handled as a job node and a
    // path with five segments as an executor node; presumably the shapes are
    // "/jobs/{group}/{name}" and "/jobs/{group}/{name}/{executor}" (the leading "/" yields
    // an empty first segment) -- confirm against the coordinator's path layout.
    private static final int JOB_PATH_DEPTH = 4;
    private static final int EXECUTOR_PATH_DEPTH = 5;
    private static final int GROUP_SEGMENT = 2;
    private static final int NAME_SEGMENT = 3;
    private static final int EXECUTOR_SEGMENT = 4;

    /** Job group -> (job name -> cached job state). */
    private final ConcurrentMap<String, ConcurrentMap<String, JobCache>> cache = new ConcurrentHashMap<>();

    @Resource
    private Coordinator coordinator;

    @Resource
    private Context context;

    /** Protocol used when a payload's header line does not name a supported protocol. */
    private Protocol defaultProtocol;
    /** Supported protocols keyed by upper-cased {@code NAME/VERSION}. */
    private Map<String, Protocol> supportProtocols;

    /**
     * Job drop-down search.
     * <p>
     * Returns the keys of all cached jobs that have at least one executor and whose group or
     * name contains the keyword, ignoring case. An empty or missing keyword matches every job.
     *
     * @param param job drop-down search parameters
     * @return job drop-down search result
     * @summary Job drop-down search
     */
    @GetMapping
    @ResponseBody
    public JobSearchResult search(JobSearchParam param) {
        String keyword = param.getKeyword();
        // Fold the keyword once and locale-independently; null means "no filter".
        String needle = StringUtils.hasLength(keyword) ? keyword.toLowerCase(Locale.ROOT) : null;

        List<JobKey> jobKeys = new ArrayList<>();
        for (Map.Entry<String, ConcurrentMap<String, JobCache>> groupEntry : cache.entrySet()) {
            String jobGroup = groupEntry.getKey();
            // The group-level match does not depend on the job, so decide it once per group.
            boolean groupMatches = needle == null || jobGroup.toLowerCase(Locale.ROOT).contains(needle);

            for (Map.Entry<String, JobCache> jobEntry : groupEntry.getValue().entrySet()) {
                // Jobs without any executor are not offered in the search result.
                if (jobEntry.getValue().getJobExecutors().isEmpty()) {
                    continue;
                }
                String jobName = jobEntry.getKey();
                // needle is non-null whenever groupMatches is false.
                if (groupMatches || jobName.toLowerCase(Locale.ROOT).contains(needle)) {
                    jobKeys.add(new JobKey(jobName, jobGroup));
                }
            }
        }

        JobSearchResult result = new JobSearchResult();
        result.setJobKeys(jobKeys);
        return result;
    }

    /**
     * Obtains a job definition.
     * The group and name come from the {@code jobGroup} and {@code jobName} of a job
     * drop-down search result.
     * <p style="color: red;"><b>Note: the path must be URL-encoded because it may contain Chinese characters!</b></p>
     *
     * @param jobGroup job group taken from the request path
     * @param jobName  job name taken from the request path
     * @param param    job definition obtain parameters
     * @return job definition obtain result; a 404 error result when the job is not cached
     * @summary Obtain a job definition
     */
    @GetMapping("/{jobGroup}/{jobName}")
    @ResponseBody
    public JobObtainResult obtain(@PathVariable("jobGroup") String jobGroup, @PathVariable("jobName") String jobName, JobObtainParam param) {
        JobObtainResult result = new JobObtainResult();

        JobCache jobCache = findJob(jobGroup, jobName);
        if (jobCache == null) {
            return result.error(404, "作业不存在");
        }

        result.setJobDefinition(jobCache.getJobDefinition().get());
        return result;
    }

    /**
     * Applies a coordination tree event to the cache.
     * <p>
     * Created and updated nodes are stored, removed nodes are evicted, and every other
     * event type is ignored.
     *
     * @param event the tree event to apply
     * @throws Exception if the node payload cannot be unmarshalled
     */
    @Override
    public void onWatched(TreeEvent event) throws Exception {
        switch (event.getType()) {
            case NODE_CREATED:
            case NODE_UPDATED:
                store(event.getNode());
                break;
            case NODE_REMOVED:
                evict(event.getNode());
                break;
            default:
                break;
        }
    }

    /** Stores the payload of a created or updated job node or executor node. */
    private void store(Node node) throws UnmarshallingException {
        String[] segments = node.getPath().split("/");

        if (segments.length == JOB_PATH_DEPTH) {
            JobDefinition jobDefinition = unmarshal(node.getData());

            // Reuse the group's map if another event already created it.
            ConcurrentMap<String, JobCache> jobs = new ConcurrentHashMap<>();
            ConcurrentMap<String, JobCache> existingJobs = cache.putIfAbsent(segments[GROUP_SEGMENT], jobs);
            if (existingJobs != null) {
                jobs = existingJobs;
            }

            // Reuse the job's cache entry if present, then publish the latest definition.
            JobCache jobCache = new JobCache(jobDefinition);
            JobCache existingCache = jobs.putIfAbsent(segments[NAME_SEGMENT], jobCache);
            if (existingCache != null) {
                jobCache = existingCache;
            }
            jobCache.getJobDefinition().set(jobDefinition);
        } else if (segments.length == EXECUTOR_PATH_DEPTH) {
            JobDefinition jobDefinition = unmarshal(node.getData());

            // Executors of jobs that are not cached are dropped.
            JobCache jobCache = findJob(segments[GROUP_SEGMENT], segments[NAME_SEGMENT]);
            if (jobCache != null) {
                jobCache.getJobExecutors().put(segments[EXECUTOR_SEGMENT], jobDefinition);
            }
        }
    }

    /** Evicts the cache entry that belongs to a removed job node or executor node. */
    private void evict(Node node) {
        String[] segments = node.getPath().split("/");

        if (segments.length == JOB_PATH_DEPTH) {
            ConcurrentMap<String, JobCache> jobs = cache.get(segments[GROUP_SEGMENT]);
            if (jobs != null) {
                jobs.remove(segments[NAME_SEGMENT]);
            }
        } else if (segments.length == EXECUTOR_PATH_DEPTH) {
            JobCache jobCache = findJob(segments[GROUP_SEGMENT], segments[NAME_SEGMENT]);
            if (jobCache != null) {
                jobCache.getJobExecutors().remove(segments[EXECUTOR_SEGMENT]);
            }
        }
    }

    /** Returns the cached state of the given job, or {@code null} when the group or job is unknown. */
    private JobCache findJob(String jobGroup, String jobName) {
        ConcurrentMap<String, JobCache> jobs = cache.get(jobGroup);
        return jobs != null ? jobs.get(jobName) : null;
    }

    /**
     * Marshals an object with the default protocol.
     *
     * @param obj the object to marshal
     * @return the marshalled bytes
     * @throws MarshallingException if the default protocol fails to marshal the object
     */
    protected byte[] marshall(Object obj) throws MarshallingException {
        return defaultProtocol.marshall(obj);
    }

    /**
     * Unmarshals a job definition from a node payload.
     * <p>
     * The first line of the payload is read as a protocol header and the remaining bytes as
     * the body. The header is looked up case-insensitively among the supported protocols;
     * when it is absent or unknown, the default protocol decodes the body.
     *
     * @param data the raw node payload
     * @return the decoded job definition
     * @throws UnmarshallingException if {@code data} is {@code null}, reading the payload
     *                                fails, or the selected protocol cannot decode the body
     */
    protected JobDefinition unmarshal(byte[] data) throws UnmarshallingException {
        if (data == null) {
            // Fail with the declared exception type rather than a bare NullPointerException.
            throw new UnmarshallingException(new IOException("job definition data is null"));
        }
        try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) {
            String head = ResKit.readln(bis);
            // Everything after the header line is the body.
            byte[] body = new byte[bis.available()];
            //noinspection ResultOfMethodCallIgnored
            bis.read(body);

            // Keys are registered trimmed and upper-cased with Locale.ROOT, so normalise the
            // header the same way; default-locale folding would break the lookup.
            Protocol protocol = head == null ? null : supportProtocols.get(head.trim().toUpperCase(Locale.ROOT));
            if (protocol == null) {
                protocol = defaultProtocol;
            }
            return protocol.unmarshal(body, JobDefinition.class);
        } catch (IOException e) {
            throw new UnmarshallingException(e);
        }
    }

    /**
     * Collects the supported protocols, picks the default protocol and starts watching the
     * {@code /jobs} subtree.
     *
     * @param applicationContext the context used to discover {@link Protocol} beans
     * @throws BeansException if a protocol is misconfigured, none is available, a protocol
     *                        fails to initialize, or the watch cannot be registered
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        // If no protocols were set, discover all of them from the application context.
        if (supportProtocols == null || supportProtocols.isEmpty()) {
            supportProtocols = Collections.unmodifiableMap(discoverProtocols(applicationContext));
        }

        // If no default protocol was configured, use the first supported one.
        if (defaultProtocol == null) {
            if (supportProtocols.isEmpty()) {
                throw new BeanInitializationException("please config at least one protocol");
            }
            defaultProtocol = supportProtocols.values().iterator().next();
        }

        try {
            coordinator.watch("/jobs", this);
        } catch (CoordinatorException e) {
            throw new BeanInitializationException(e.getMessage(), e);
        }
    }

    /**
     * Validates and initializes every {@link Protocol} bean and indexes it by upper-cased
     * {@code NAME/VERSION}, preserving the order in which the context returned the beans.
     */
    private Map<String, Protocol> discoverProtocols(ApplicationContext applicationContext) {
        Map<String, Protocol> protocols = new LinkedHashMap<>();
        for (Protocol protocol : applicationContext.getBeansOfType(Protocol.class).values()) {
            String name = protocol.getName();
            if (name == null || name.trim().isEmpty()) {
                throw new BeanInitializationException("protocol " + protocol.getClass() + " must declare a not null or blank name");
            }
            String version = protocol.getVersion();
            if (version == null || version.trim().isEmpty()) {
                throw new BeanInitializationException("protocol " + protocol.getClass() + " must declare a not null or blank version");
            }
            if (protocol instanceof Lifecycle) {
                try {
                    ((Lifecycle) protocol).initialize(context);
                } catch (InitializingException e) {
                    throw new BeanInitializationException(e.getMessage(), e);
                }
            }
            // Locale.ROOT keeps the key stable regardless of the JVM's default locale.
            String key = name.trim().toUpperCase(Locale.ROOT) + "/" + version.trim().toUpperCase(Locale.ROOT);
            protocols.put(key, protocol);
        }
        return protocols;
    }
}
